A2C (Advantage Actor-Critic) — Low-Level PyTorch Implementation (CartPole-v1)#

A2C is an on-policy actor-critic algorithm:

  • the actor learns a policy \(\pi_\theta(a\mid s)\) (how to act)

  • the critic learns a value function \(V_\phi(s)\) (how good a state is)

  • the actor is trained with advantages (“better than expected” signals)

This notebook builds the math carefully, then implements A2C with minimal PyTorch (no RL libraries, no high-level training abstractions), using a vectorized Gymnasium environment for synchronous rollouts.


Learning goals#

By the end you should be able to:

  • derive the A2C update from the policy gradient theorem

  • explain why the baseline (critic) reduces variance

  • implement GAE(\(\gamma,\lambda\)) and n-step bootstrapped returns

  • train an A2C agent on CartPole-v1 and visualize learning with Plotly

  • map the concepts to Stable-Baselines3 A2C hyperparameters

Notebook roadmap#

  1. A2C intuition + what “advantage” means

  2. Mathematical formulation (LaTeX)

  3. Low-level PyTorch implementation (actor + critic)

  4. Training on CartPole with vectorized rollouts

  5. Plotly diagnostics (returns, losses, policy/value slices)

  6. Stable-Baselines3 A2C reference + hyperparameters

import math
import time

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import os
import plotly.io as pio
from plotly.subplots import make_subplots

try:
    import gymnasium as gym
    GYMNASIUM_AVAILABLE = True
except Exception as e:
    GYMNASIUM_AVAILABLE = False
    _GYM_IMPORT_ERROR = e

try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    TORCH_AVAILABLE = True
except Exception as e:
    TORCH_AVAILABLE = False
    _TORCH_IMPORT_ERROR = e


pio.templates.default = "plotly_white"
pio.renderers.default = os.environ.get("PLOTLY_RENDERER", "notebook")
np.set_printoptions(precision=4, suppress=True)

assert GYMNASIUM_AVAILABLE, f"gymnasium import failed: {_GYM_IMPORT_ERROR}"
assert TORCH_AVAILABLE, f"torch import failed: {_TORCH_IMPORT_ERROR}"

print('gymnasium', gym.__version__)
print('torch', torch.__version__)
gymnasium 1.1.1
torch 2.7.0+cu126
# --- Run configuration ---

# Keep FAST_RUN=True for a quick demo.
# For a more reliable "solve", set FAST_RUN=False.
FAST_RUN = True

ENV_ID = "CartPole-v1"  # discrete actions, small continuous state
SEED = 42

# A2C is usually run with multiple envs in parallel.
N_ENVS = 8 if FAST_RUN else 16

# Rollout horizon per env (A2C commonly uses small n_steps).
N_STEPS = 5

# Total interaction budget
TOTAL_TIMESTEPS = 30_000 if FAST_RUN else 200_000

# Core RL hyperparameters
GAMMA = 0.99
GAE_LAMBDA = 1.0  # 1.0 => classic advantage w/ n-step bootstrapping

# Loss weights
ENT_COEF = 0.01
VF_COEF = 0.5

# Optimization
LR = 7e-4
MAX_GRAD_NORM = 0.5
RMSPROP_EPS = 1e-5

# Optional: normalize advantage each update
NORMALIZE_ADVANTAGE = True

# Network
HIDDEN_SIZES = (128, 128)

# Logging
LOG_EVERY_UPDATES = 50
RETURN_SMOOTHING_WINDOW = 50

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('device', DEVICE)
device cpu
1) A2C intuition: actor + critic + advantage#

Actor#

The actor is a stochastic policy \(\pi_\theta(a\mid s)\).

  • It outputs a distribution over actions.

  • We sample actions from that distribution to explore.

Critic#

The critic is a value function \(V_\phi(s)\).

  • It predicts the expected discounted return from state \(s\).

  • It is trained via regression to match a bootstrapped return target.

Advantage#

The advantage measures how much better an action did compared to what the critic expected:

\[ A(s_t, a_t) = Q(s_t, a_t) - V(s_t). \]

If \(A(s_t,a_t)\) is positive, the action was better than expected, and the actor should increase its probability.

Why “A2C”?#

A2C is the synchronous version of A3C:

  • A3C: many workers update parameters asynchronously.

  • A2C: many workers collect experience in parallel, then we do a single synchronized update.

In practice, A2C typically uses a vectorized environment and batches data as:

\[ \text{batch size} = n_{\text{env}} \times n_{\text{steps}}. \]
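For the configuration used in this notebook (N_ENVS = 8 under FAST_RUN, N_STEPS = 5), that works out to:

```python
# Batch size per synchronized A2C update = n_envs * n_steps
# (the FAST_RUN settings used below: 8 envs, 5-step rollouts).
n_envs, n_steps = 8, 5
batch_size = n_envs * n_steps
print(batch_size)  # 40 transitions per update
```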

2) Mathematical formulation (policy gradient + baseline)#

We model the environment as an MDP \((\mathcal{S}, \mathcal{A}, P, r, \gamma)\).

Return#

The discounted return from time \(t\) is:

\[ G_t = \sum_{k=0}^{\infty} \gamma^k r_{t+k}. \]

Objective#

We want to maximize expected return:

\[ J(\theta) = \mathbb{E}_{\tau\sim\pi_\theta}\left[\sum_{t=0}^{\infty} \gamma^t r_t\right]. \]

Policy gradient theorem#

A standard form is:

\[ \nabla_\theta J(\theta) = \mathbb{E}_{\pi_\theta}\left[\nabla_\theta \log \pi_\theta(a_t\mid s_t)\, Q^{\pi_\theta}(s_t, a_t)\right]. \]

Baseline (variance reduction)#

We can subtract a baseline \(b(s_t)\) without changing the expectation:

\[ \mathbb{E}[\nabla_\theta \log \pi_\theta(a_t\mid s_t)\, b(s_t)] = 0. \]
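This identity is a one-line consequence of the policy normalizing to 1:

\[ \mathbb{E}_{a\sim\pi_\theta}\left[\nabla_\theta \log \pi_\theta(a\mid s)\, b(s)\right] = b(s) \sum_a \pi_\theta(a\mid s)\, \frac{\nabla_\theta \pi_\theta(a\mid s)}{\pi_\theta(a\mid s)} = b(s)\, \nabla_\theta \sum_a \pi_\theta(a\mid s) = b(s)\, \nabla_\theta 1 = 0. \]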

Choosing \(b(s_t)=V_\phi(s_t)\) yields the advantage form:

\[ \nabla_\theta J(\theta) = \mathbb{E}_{\pi_\theta}\left[\nabla_\theta \log \pi_\theta(a_t\mid s_t)\, A_t\right], \quad A_t \approx \hat{A}(s_t,a_t). \]

Bootstrapped n-step return#

With a rollout horizon \(T\) (a.k.a. n_steps), we use a bootstrapped target:

\[ \hat{R}_t = \sum_{k=0}^{T-1-t} \gamma^k r_{t+k} + \gamma^{T-t} V_\phi(s_T). \]
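Computed backwards from the bootstrap value, this target is a simple recursion. A minimal single-environment sketch (toy numbers, not the vectorized buffers used later):

```python
def n_step_returns(rewards, last_value, gamma):
    """Backward recursion: R_t = r_t + gamma * R_{t+1}, seeded with V(s_T)."""
    returns = []
    running = last_value
    for r in reversed(rewards):
        running = r + gamma * running
        returns.append(running)
    return list(reversed(returns))

rewards = [1.0, 1.0, 1.0]  # T = 3 steps of CartPole-style reward
print(n_step_returns(rewards, last_value=10.0, gamma=0.5))
# R_2 = 1 + 0.5*10 = 6; R_1 = 1 + 0.5*6 = 4; R_0 = 1 + 0.5*4 = 3
```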

Generalized Advantage Estimation (GAE)#

GAE defines the TD residual:

\[ \delta_t = r_t + \gamma V_\phi(s_{t+1}) - V_\phi(s_t) \]

and computes advantages with an exponentially-weighted sum:

\[ \hat{A}_t^{\mathrm{GAE}(\gamma,\lambda)} = \sum_{l=0}^{\infty} (\gamma\lambda)^l\, \delta_{t+l}. \]

  • \(\lambda=1\) recovers the classic (higher-variance) advantage.

  • smaller \(\lambda\) reduces variance but increases bias.
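A scalar sketch (single environment, toy rewards and values chosen for illustration) makes the trade-off concrete: \(\lambda=1\) reproduces the full n-step advantage, while \(\lambda=0\) collapses to one-step TD residuals:

```python
def gae_1d(rewards, values, last_value, gamma, lam):
    """Scalar GAE for one env: A_t = delta_t + gamma*lam*A_{t+1}."""
    vals = values + [last_value]
    adv, last_adv = [0.0] * len(rewards), 0.0
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * vals[t + 1] - vals[t]
        last_adv = delta + gamma * lam * last_adv
        adv[t] = last_adv
    return adv

rewards, values, last_value = [1.0, 1.0, 1.0], [0.0, 0.0, 0.0], 10.0
print(gae_1d(rewards, values, last_value, gamma=0.5, lam=1.0))  # [3.0, 4.0, 6.0]
print(gae_1d(rewards, values, last_value, gamma=0.5, lam=0.0))  # [1.0, 1.0, 6.0]
```

With all value estimates at zero, the \(\lambda=1\) advantages coincide with the n-step bootstrapped returns, as expected; \(\lambda=0\) keeps only the local TD errors.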

Loss functions (minimization form)#

Actor loss (to maximize expected advantage):

\[ \mathcal{L}_{\text{actor}}(\theta) = -\mathbb{E}\left[\log \pi_\theta(a_t\mid s_t)\, \hat{A}_t\right]. \]

Critic loss (value regression):

\[ \mathcal{L}_{\text{critic}}(\phi) = \frac{1}{2}\,\mathbb{E}\left[(\hat{R}_t - V_\phi(s_t))^2\right]. \]

Entropy bonus (encourage exploration):

\[ \mathcal{L}_{\text{entropy}}(\theta) = -\mathbb{E}\left[\mathcal{H}(\pi_\theta(\cdot\mid s_t))\right]. \]

Total loss:

\[ \mathcal{L} = \mathcal{L}_{\text{actor}} + c_v\,\mathcal{L}_{\text{critic}} + c_e\,\mathcal{L}_{\text{entropy}}. \]
def make_vec_env(env_id: str, n_envs: int, seed: int) -> gym.vector.SyncVectorEnv:
    env_fns = [lambda: gym.make(env_id) for _ in range(n_envs)]
    env = gym.vector.SyncVectorEnv(env_fns, autoreset_mode=gym.vector.AutoresetMode.SAME_STEP)
    env.reset(seed=[seed + i for i in range(n_envs)])
    return env


env = make_vec_env(ENV_ID, N_ENVS, SEED)

obs_space = env.single_observation_space
act_space = env.single_action_space

assert isinstance(act_space, gym.spaces.Discrete), "This notebook's implementation uses discrete actions (Categorical)."

OBS_DIM = int(np.prod(obs_space.shape))
N_ACTIONS = int(act_space.n)

print('obs_space', obs_space)
print('act_space', act_space)
print('OBS_DIM', OBS_DIM, 'N_ACTIONS', N_ACTIONS)
obs_space Box([-4.8       -inf -0.4189    -inf], [4.8       inf 0.4189    inf], (4,), float32)
act_space Discrete(2)
OBS_DIM 4 N_ACTIONS 2

3) Actor-Critic network (low-level PyTorch)#

We use a shared MLP trunk, then two heads:

  • actor head outputs logits for a categorical distribution

  • critic head outputs a scalar value \(V(s)\)

This is not the only architecture (you can also use separate networks), but a shared trunk is a common and effective default.

class ActorCritic(nn.Module):
    def __init__(self, obs_dim: int, n_actions: int, hidden_sizes: tuple[int, int] = (128, 128)):
        super().__init__()

        h1, h2 = hidden_sizes
        self.fc1 = nn.Linear(obs_dim, h1)
        self.fc2 = nn.Linear(h1, h2)

        self.actor = nn.Linear(h2, n_actions)
        self.critic = nn.Linear(h2, 1)

    def forward(self, obs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # obs: (B, obs_dim)
        x = torch.tanh(self.fc1(obs))
        x = torch.tanh(self.fc2(x))
        logits = self.actor(x)            # (B, n_actions)
        values = self.critic(x).squeeze(-1)  # (B,)
        return logits, values


def sample_actions_and_logp(logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Low-level categorical sampling without torch.distributions.

    Returns:
      actions: (B,) int64
      logp:    (B,) log-prob of sampled action
      entropy: (B,) categorical entropy
    """
    log_probs = F.log_softmax(logits, dim=-1)  # (B, A)
    probs = log_probs.exp()

    actions = torch.multinomial(probs, num_samples=1).squeeze(-1)  # (B,)
    logp = log_probs.gather(1, actions.unsqueeze(1)).squeeze(1)
    entropy = -(probs * log_probs).sum(dim=-1)

    return actions, logp, entropy


@torch.no_grad()
def policy_action_probs(logits: torch.Tensor) -> torch.Tensor:
    return F.softmax(logits, dim=-1)
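The quantities this helper computes (action probabilities, log-probability of a chosen action, entropy) can be checked by hand for a toy two-action case (plain Python, assumed logits [2.0, 0.0]):

```python
import math

logits = [2.0, 0.0]                        # toy two-action logits
z = sum(math.exp(l) for l in logits)
probs = [math.exp(l) / z for l in logits]  # softmax
logp_a0 = math.log(probs[0])               # log-prob of taking action 0
entropy = -sum(p * math.log(p) for p in probs)
print(round(probs[0], 3), round(entropy, 3))  # 0.881 0.365
```

A uniform two-action policy would instead give entropy \(\log 2 \approx 0.693\), the maximum for two actions.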

4) GAE implementation#

We compute advantages backwards in time:

\[ \delta_t = r_t + \gamma (1-d_t) V(s_{t+1}) - V(s_t) \]
\[ A_t = \delta_t + \gamma\lambda(1-d_t) A_{t+1} \]

where \(d_t\in\{0,1\}\) is the done flag.

def compute_gae(
    rewards: torch.Tensor,   # (T, N)
    dones: torch.Tensor,     # (T, N) float32 {0,1}
    values: torch.Tensor,    # (T, N)
    last_values: torch.Tensor,  # (N,)
    gamma: float,
    gae_lambda: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    T, N = rewards.shape
    advantages = torch.zeros((T, N), device=rewards.device, dtype=torch.float32)
    last_adv = torch.zeros((N,), device=rewards.device, dtype=torch.float32)

    for t in reversed(range(T)):
        mask = 1.0 - dones[t]
        next_values = last_values if t == T - 1 else values[t + 1]
        delta = rewards[t] + gamma * mask * next_values - values[t]
        last_adv = delta + gamma * gae_lambda * mask * last_adv
        advantages[t] = last_adv

    returns = advantages + values
    return advantages, returns

5) Training loop (A2C)#

Key design choices in this minimal implementation:

  • Vectorized envs (n_envs) to match A2C’s synchronous batching.

  • Rollout buffer of shape (n_steps, n_envs, ...).

  • Compute GAE + bootstrapped returns.

  • Single gradient update per rollout (no replay buffer, no off-policy corrections).

We also record:

  • episodic return (score) whenever any env finishes an episode

  • actor loss, critic loss, entropy, explained variance (optional diagnostic)

def explained_variance(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    var_y = np.var(y_true)
    if var_y < 1e-12:
        return float('nan')
    return 1.0 - float(np.var(y_true - y_pred) / var_y)


def train_a2c(
    env_id: str,
    seed: int,
    device: torch.device,
    n_envs: int,
    n_steps: int,
    total_timesteps: int,
    gamma: float,
    gae_lambda: float,
    ent_coef: float,
    vf_coef: float,
    lr: float,
    max_grad_norm: float,
    rmsprop_eps: float,
    hidden_sizes: tuple[int, int],
    normalize_advantage: bool,
    log_every_updates: int = 50,
    ):
    torch.manual_seed(seed)
    np.random.seed(seed)

    env = make_vec_env(env_id, n_envs, seed)
    obs_space = env.single_observation_space
    act_space = env.single_action_space

    obs_dim = int(np.prod(obs_space.shape))
    n_actions = int(act_space.n)

    model = ActorCritic(obs_dim, n_actions, hidden_sizes=hidden_sizes).to(device)
    optimizer = torch.optim.RMSprop(model.parameters(), lr=lr, eps=rmsprop_eps)

    # Rollout buffers
    obs_buf = torch.zeros((n_steps, n_envs, obs_dim), device=device, dtype=torch.float32)
    act_buf = torch.zeros((n_steps, n_envs), device=device, dtype=torch.int64)
    rew_buf = torch.zeros((n_steps, n_envs), device=device, dtype=torch.float32)
    done_buf = torch.zeros((n_steps, n_envs), device=device, dtype=torch.float32)
    val_buf = torch.zeros((n_steps, n_envs), device=device, dtype=torch.float32)

    obs, _ = env.reset(seed=[seed + i for i in range(n_envs)])

    # Episode tracking across vector envs
    ep_returns_running = np.zeros((n_envs,), dtype=np.float32)
    ep_lengths_running = np.zeros((n_envs,), dtype=np.int32)
    ep_returns: list[float] = []
    ep_lengths: list[int] = []

    updates = total_timesteps // (n_envs * n_steps)
    history_updates: list[dict] = []
    last_adv_flat = None

    t0 = time.time()
    global_step = 0
    model.train()

    for update in range(1, updates + 1):
        # --- Collect rollout ---
        for t in range(n_steps):
            obs_t = torch.as_tensor(obs, dtype=torch.float32, device=device)
            obs_buf[t] = obs_t

            with torch.no_grad():
                logits, values = model(obs_t)
                actions, _, _ = sample_actions_and_logp(logits)

            act_buf[t] = actions
            val_buf[t] = values

            next_obs, rewards, terminated, truncated, _ = env.step(actions.cpu().numpy())
            dones = np.logical_or(terminated, truncated)

            rew_buf[t] = torch.as_tensor(rewards, dtype=torch.float32, device=device)
            done_buf[t] = torch.as_tensor(dones, dtype=torch.float32, device=device)

            # Episode bookkeeping
            ep_returns_running += rewards
            ep_lengths_running += 1
            for i in np.where(dones)[0]:
                ep_returns.append(float(ep_returns_running[i]))
                ep_lengths.append(int(ep_lengths_running[i]))
                ep_returns_running[i] = 0.0
                ep_lengths_running[i] = 0

            obs = next_obs
            global_step += n_envs

        # Bootstrap value from last observation
        with torch.no_grad():
            obs_last = torch.as_tensor(obs, dtype=torch.float32, device=device)
            _, last_values = model(obs_last)  # (N,)

        advantages, returns = compute_gae(
            rewards=rew_buf,
            dones=done_buf,
            values=val_buf,
            last_values=last_values,
            gamma=gamma,
            gae_lambda=gae_lambda,
        )

        # Flatten (T, N, ...) -> (T*N, ...)
        b_obs = obs_buf.reshape(-1, obs_dim)
        b_act = act_buf.reshape(-1)
        b_adv = advantages.reshape(-1)
        b_ret = returns.reshape(-1)

        if normalize_advantage:
            b_adv = (b_adv - b_adv.mean()) / (b_adv.std() + 1e-8)

        # --- Compute losses ---
        logits, values_pred = model(b_obs)
        log_probs = F.log_softmax(logits, dim=-1)
        probs = log_probs.exp()

        b_logp = log_probs.gather(1, b_act.unsqueeze(1)).squeeze(1)
        entropy = -(probs * log_probs).sum(dim=-1).mean()

        actor_loss = -(b_logp * b_adv.detach()).mean()
        critic_loss = 0.5 * F.mse_loss(values_pred, b_ret.detach())

        loss = actor_loss + vf_coef * critic_loss - ent_coef * entropy

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()

        last_adv_flat = b_adv.detach().cpu().numpy()

        # Diagnostics
        y_true = b_ret.detach().cpu().numpy()
        y_pred = values_pred.detach().cpu().numpy()

        mean_ep_return = float(np.mean(ep_returns[-RETURN_SMOOTHING_WINDOW:])) if len(ep_returns) else float('nan')

        history_updates.append(
            dict(
                update=update,
                timesteps=global_step,
                actor_loss=float(actor_loss.detach().cpu().item()),
                critic_loss=float(critic_loss.detach().cpu().item()),
                entropy=float(entropy.detach().cpu().item()),
                explained_variance=explained_variance(y_true, y_pred),
                episodes=len(ep_returns),
                mean_return_window=mean_ep_return,
            )
        )

        if update % log_every_updates == 0 or update == 1 or update == updates:
            elapsed = time.time() - t0
            print(
                f"update {update:>4d}/{updates} | steps {global_step:>7d} | episodes {len(ep_returns):>5d} | "
                f"mean_return@{RETURN_SMOOTHING_WINDOW} {mean_ep_return:>7.1f} | "
                f"loss {float(loss.detach().cpu()):>8.4f} | {elapsed:>6.1f}s"
            )

    env.close()
    hist_df = pd.DataFrame(history_updates)
    return model, hist_df, np.array(ep_returns, dtype=np.float32), np.array(ep_lengths, dtype=np.int32), last_adv_flat
model, hist_df, ep_returns, ep_lengths, last_adv_flat = train_a2c(
    env_id=ENV_ID,
    seed=SEED,
    device=DEVICE,
    n_envs=N_ENVS,
    n_steps=N_STEPS,
    total_timesteps=TOTAL_TIMESTEPS,
    gamma=GAMMA,
    gae_lambda=GAE_LAMBDA,
    ent_coef=ENT_COEF,
    vf_coef=VF_COEF,
    lr=LR,
    max_grad_norm=MAX_GRAD_NORM,
    rmsprop_eps=RMSPROP_EPS,
    hidden_sizes=HIDDEN_SIZES,
    normalize_advantage=NORMALIZE_ADVANTAGE,
    log_every_updates=LOG_EVERY_UPDATES,
)

hist_df.tail()
update    1/750 | steps      40 | episodes     0 | mean_return@50     nan | loss   2.6692 |    0.0s
update   50/750 | steps    2000 | episodes    82 | mean_return@50    24.9 | loss   1.6804 |    0.2s
update  100/750 | steps    4000 | episodes   146 | mean_return@50    31.8 | loss   7.0746 |    0.5s
update  150/750 | steps    6000 | episodes   211 | mean_return@50    33.0 | loss  12.9465 |    0.8s
update  200/750 | steps    8000 | episodes   266 | mean_return@50    37.3 | loss  27.3818 |    1.0s
update  250/750 | steps   10000 | episodes   309 | mean_return@50    42.7 | loss   7.7835 |    1.3s
update  300/750 | steps   12000 | episodes   343 | mean_return@50    54.5 | loss   1.4184 |    1.5s
update  350/750 | steps   14000 | episodes   371 | mean_return@50    65.0 | loss  19.3684 |    1.8s
update  400/750 | steps   16000 | episodes   396 | mean_return@50    71.7 | loss   1.0661 |    2.1s
update  450/750 | steps   18000 | episodes   431 | mean_return@50    67.6 | loss  13.2570 |    2.4s
update  500/750 | steps   20000 | episodes   457 | mean_return@50    65.4 | loss   1.0435 |    2.7s
update  550/750 | steps   22000 | episodes   487 | mean_return@50    71.2 | loss   1.0582 |    2.9s
update  600/750 | steps   24000 | episodes   514 | mean_return@50    67.5 | loss   0.9999 |    3.0s
update  650/750 | steps   26000 | episodes   551 | mean_return@50    52.4 | loss   1.1563 |    3.1s
update  700/750 | steps   28000 | episodes   574 | mean_return@50    63.4 | loss   2.1548 |    3.3s
update  750/750 | steps   30000 | episodes   600 | mean_return@50    79.5 | loss   1.1326 |    3.4s
update timesteps actor_loss critic_loss entropy explained_variance episodes mean_return_window
745 746 29840 -0.141999 2.279103 0.612243 -0.273748 600 79.48
746 747 29880 -0.164299 2.443412 0.641806 -0.466463 600 79.48
747 748 29920 -0.135936 3.744340 0.590112 -1.532006 600 79.48
748 749 29960 -0.104616 3.020100 0.607867 -0.825297 600 79.48
749 750 30000 -0.107253 2.492551 0.637384 -0.567428 600 79.48

6) Plot: score (return) per episode#

CartPole gives reward \(+1\) per time step, so episode return = episode length (up to 500).

episodes = np.arange(1, len(ep_returns) + 1)

roll_mean = pd.Series(ep_returns).rolling(RETURN_SMOOTHING_WINDOW).mean().to_numpy()

fig = go.Figure()
fig.add_trace(go.Scatter(x=episodes, y=ep_returns, mode='lines', name='return', line=dict(width=1)))
fig.add_trace(go.Scatter(x=episodes, y=roll_mean, mode='lines', name=f'mean@{RETURN_SMOOTHING_WINDOW}', line=dict(width=3)))
fig.update_layout(
    title='A2C on CartPole-v1 — score (return) per episode',
    xaxis_title='episode',
    yaxis_title='return',
)
fig.show()

7) Plot: training diagnostics (losses, entropy, explained variance)#

  • Actor loss becomes more negative when advantages are consistently positive for sampled actions.

  • Critic loss should generally decrease as the value function fits the returns.

  • Entropy typically decreases as the policy becomes more confident.

  • Explained variance (rough critic diagnostic): near 1 is good; near 0 means the critic explains little; negative values mean the critic's predictions are worse than simply predicting the mean target.
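A toy computation (hand-picked targets) shows the two extremes of the explained-variance diagnostic:

```python
def explained_var(y_true, y_pred):
    """1 - Var(residual) / Var(target), the diagnostic logged during training."""
    def var(xs):
        m = sum(xs) / len(xs)
        return sum((x - m) ** 2 for x in xs) / len(xs)
    return 1.0 - var([t - p for t, p in zip(y_true, y_pred)]) / var(y_true)

y = [1.0, 2.0, 3.0, 4.0]
print(explained_var(y, y))                     # perfect critic -> 1.0
print(explained_var(y, [2.5, 2.5, 2.5, 2.5]))  # constant-mean critic -> 0.0
```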

fig = make_subplots(
    rows=2,
    cols=2,
    subplot_titles=("Actor loss", "Critic loss", "Entropy", "Explained variance"),
)

fig.add_trace(go.Scatter(x=hist_df['timesteps'], y=hist_df['actor_loss'], name='actor_loss'), row=1, col=1)
fig.add_trace(go.Scatter(x=hist_df['timesteps'], y=hist_df['critic_loss'], name='critic_loss'), row=1, col=2)
fig.add_trace(go.Scatter(x=hist_df['timesteps'], y=hist_df['entropy'], name='entropy'), row=2, col=1)
fig.add_trace(go.Scatter(x=hist_df['timesteps'], y=hist_df['explained_variance'], name='explained_variance'), row=2, col=2)

fig.update_layout(height=700, title='A2C training diagnostics', showlegend=False)
fig.update_xaxes(title_text='timesteps')
fig.show()

8) Plot: advantage distribution (last update)#

A2C pushes up the log-probability of actions with positive advantage and pushes down those with negative advantage.

fig = px.histogram(
    x=last_adv_flat,
    nbins=60,
    title='Advantage histogram (last update)',
)
fig.update_layout(xaxis_title='advantage', yaxis_title='count')
fig.show()

9) Visualize the learned policy + value function (2D slice)#

CartPole states are 4D:

\[ s = (x, \dot{x}, \theta, \dot{\theta}). \]

To visualize something, we take a 2D slice over pole angle \(\theta\) and pole angular velocity \(\dot{\theta}\), while fixing \(x=0\) and \(\dot{x}=0\).

  • Left plot: \(\pi(a=1\mid s)\) (probability of pushing right)

  • Right plot: \(V(s)\) (critic estimate)

@torch.no_grad()
def policy_value_slice(model: nn.Module, device: torch.device, grid_n: int = 70):
    model.eval()
    angles = np.linspace(-0.21, 0.21, grid_n)  # roughly CartPole angle limits
    ang_vels = np.linspace(-3.0, 3.0, grid_n)

    theta, theta_dot = np.meshgrid(angles, ang_vels)
    states = np.zeros((grid_n * grid_n, 4), dtype=np.float32)
    states[:, 2] = theta.ravel()
    states[:, 3] = theta_dot.ravel()

    obs_t = torch.as_tensor(states, dtype=torch.float32, device=device)
    logits, values = model(obs_t)
    probs = policy_action_probs(logits)
    p_right = probs[:, 1].reshape(grid_n, grid_n).cpu().numpy()
    v = values.reshape(grid_n, grid_n).cpu().numpy()

    return angles, ang_vels, p_right, v


angles, ang_vels, p_right, v = policy_value_slice(model, DEVICE)

fig = make_subplots(
    rows=1,
    cols=2,
    subplot_titles=("Policy: P(push right)", "Critic: V(s)"),
)

fig.add_trace(
    go.Heatmap(
        x=angles,
        y=ang_vels,
        z=p_right,
        colorscale='RdBu',
        zmin=0.0,
        zmax=1.0,
        colorbar=dict(title='P(right)'),
    ),
    row=1,
    col=1,
)

fig.add_trace(
    go.Heatmap(
        x=angles,
        y=ang_vels,
        z=v,
        colorscale='Viridis',
        colorbar=dict(title='V(s)'),
    ),
    row=1,
    col=2,
)

fig.update_layout(
    height=420,
    title='Learned policy/value on a 2D state slice (x=0, xdot=0)',
)
fig.update_xaxes(title_text='pole angle θ', row=1, col=1)
fig.update_yaxes(title_text='pole angular velocity θdot', row=1, col=1)
fig.update_xaxes(title_text='pole angle θ', row=1, col=2)
fig.update_yaxes(title_text='pole angular velocity θdot', row=1, col=2)
fig.show()

10) Quick evaluation (deterministic actions)#

We evaluate by taking the greedy action \(\arg\max_a \pi(a\mid s)\).

@torch.no_grad()
def evaluate(model: nn.Module, env_id: str, n_episodes: int = 10, seed: int = 0):
    env = gym.make(env_id)
    returns = []
    for ep in range(n_episodes):
        obs, _ = env.reset(seed=seed + ep)
        done = False
        ret = 0.0
        while not done:
            obs_t = torch.as_tensor(obs, dtype=torch.float32, device=DEVICE).unsqueeze(0)
            logits, _ = model(obs_t)
            action = int(torch.argmax(logits, dim=-1).item())

            obs, reward, terminated, truncated, _ = env.step(action)
            done = bool(terminated or truncated)
            ret += float(reward)
        returns.append(ret)
    env.close()
    return np.array(returns, dtype=np.float32)


model.eval()
eval_returns = evaluate(model, ENV_ID, n_episodes=10, seed=SEED + 1000)
print('eval returns:', eval_returns)
print('mean ± std:', float(eval_returns.mean()), '±', float(eval_returns.std()))
eval returns: [474. 500. 319. 245. 374. 419. 211. 354. 247. 292.]
mean ± std: 343.5 ± 93.85440826416016

11) Pitfalls + diagnostics#

  • On-policy constraint: A2C uses data from the current policy. If you reuse old experience without correction, it becomes biased.

  • Done handling: You must stop bootstrapping across episode boundaries. Here we treat terminated OR truncated as terminal for simplicity; strictly, a truncated episode should still bootstrap from \(V(s_{t+1})\), so this shortcut adds a small bias at CartPole's 500-step time limit.

  • Entropy coefficient: Too high keeps the policy random; too low can collapse exploration early.

  • Critic collapse: If the critic is too weak/strong relative to the actor, learning can become unstable.

  • Parallel envs matter: With too few envs you get higher-variance updates.

Good quick checks:

  • returns increase over time

  • entropy decreases gradually (not instantly)

  • critic loss decreases and explained variance improves

12) Exercises#

  1. Change \(\lambda\) in GAE (e.g. 0.9) and compare learning curves.

  2. Swap RMSprop for Adam and compare stability.

  3. Implement continuous actions by outputting a Gaussian policy (mean + log-std) and testing on Pendulum-v1.

  4. Add a learning-rate schedule.

  5. Add observation normalization and compare speed.

13) Stable-Baselines3 A2C reference implementation#

Stable-Baselines3 (SB3) includes an A2C implementation.

  • Docs page: https://stable-baselines3.readthedocs.io/en/master/modules/a2c.html

Minimal usage#

from stable_baselines3 import A2C
import gymnasium as gym

env = gym.make("CartPole-v1")
model = A2C(
    policy="MlpPolicy",
    env=env,
    learning_rate=7e-4,
    n_steps=5,
    gamma=0.99,
    gae_lambda=1.0,
    ent_coef=0.0,
    vf_coef=0.5,
    max_grad_norm=0.5,
    rms_prop_eps=1e-5,
    use_rms_prop=True,
    normalize_advantage=False,
)
model.learn(total_timesteps=200_000)

SB3 A2C hyperparameters (signature + meaning)#

From the SB3 docs, the constructor signature is:

A2C(policy, env, learning_rate=0.0007, n_steps=5, gamma=0.99, gae_lambda=1.0,
    ent_coef=0.0, vf_coef=0.5, max_grad_norm=0.5, rms_prop_eps=1e-05,
    use_rms_prop=True, use_sde=False, sde_sample_freq=-1,
    rollout_buffer_class=None, rollout_buffer_kwargs=None,
    normalize_advantage=False, stats_window_size=100, tensorboard_log=None,
    policy_kwargs=None, verbose=0, seed=None, device='auto', _init_setup_model=True)

Parameter meanings (SB3 docs):

  • policy: policy class (e.g. MlpPolicy, CnnPolicy)

  • env: environment (Gym env, VecEnv, or registered env id string)

  • learning_rate: float or schedule

  • n_steps: rollout length per env (batch size = n_steps * n_env)

  • gamma: discount factor

  • gae_lambda: bias/variance trade-off for GAE; 1.0 equals classic advantage

  • ent_coef: entropy coefficient

  • vf_coef: value loss coefficient

  • max_grad_norm: gradient clipping threshold

  • rms_prop_eps: RMSprop epsilon

  • use_rms_prop: use RMSprop (default) vs Adam

  • use_sde: generalized State Dependent Exploration (gSDE)

  • sde_sample_freq: resample gSDE noise every n steps (-1 = only at rollout start)

  • rollout_buffer_class: custom rollout buffer class

  • rollout_buffer_kwargs: kwargs for rollout buffer

  • normalize_advantage: normalize advantages

  • stats_window_size: episodes window for logging averages

  • tensorboard_log: tensorboard log dir

  • policy_kwargs: kwargs for policy network/architecture

  • verbose: verbosity level

  • seed: random seed

  • device: cpu, cuda, or auto

  • _init_setup_model: build the network immediately

References#

  • Mnih et al. (2016), Asynchronous Methods for Deep Reinforcement Learning (A3C)

  • Schulman et al. (2016), High-Dimensional Continuous Control Using Generalized Advantage Estimation

  • Stable-Baselines3 A2C docs: https://stable-baselines3.readthedocs.io/en/master/modules/a2c.html